{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Load the source table; later cells read its 'content' and 'score' columns.\n",
    "df = pd.read_csv(filepath_or_buffer='../selected-ann.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import KFold, StratifiedKFold\n",
    "import json\n",
    "import os\n",
    "\n",
    "\n",
    "def _to_record(row_pos):\n",
    "    # Build one prompt/target pair for the row at *positional* index row_pos.\n",
    "    # skf.split() yields positions, so .iloc is used; label lookup (df[col][i])\n",
    "    # only coincides with positions while df keeps a default RangeIndex.\n",
    "    return {\n",
    "        'content': \"作文：\" + df['content'].iloc[row_pos],\n",
    "        'summary': \"评分：\" + str(df['score'].iloc[row_pos]) + \"分\",\n",
    "    }\n",
    "\n",
    "\n",
    "def _write_jsonl(path, records):\n",
    "    # Write one JSON object per line (UTF-8, non-ASCII kept readable).\n",
    "    # Streaming each line avoids building the whole file with repeated string +=.\n",
    "    with open(path, 'w', encoding='utf-8') as f:\n",
    "        for record in records:\n",
    "            f.write(json.dumps(record, ensure_ascii=False) + '\\n')\n",
    "\n",
    "\n",
    "# Stratify on the score so every fold keeps the same score distribution;\n",
    "# the fixed random_state makes the split reproducible.\n",
    "skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)\n",
    "X = df['content']\n",
    "Y = df['score']\n",
    "\n",
    "# Without this, a fresh checkout fails with FileNotFoundError on the first open().\n",
    "os.makedirs('data', exist_ok=True)\n",
    "\n",
    "for i, (train_index, test_index) in enumerate(skf.split(X, Y)):\n",
    "    train_dict = [_to_record(index) for index in train_index]\n",
    "    val_dict = [_to_record(index) for index in test_index]\n",
    "    print(len(train_dict), len(val_dict))\n",
    "    # Fold i is written to data/train<i>.json and data/val<i>.json.\n",
    "    _write_jsonl('data/train' + str(i) + '.json', train_dict)\n",
    "    _write_jsonl('data/val' + str(i) + '.json', val_dict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "# Single source of truth for the checkpoint directory, so the tokenizer and the\n",
    "# model can never be loaded from two different paths by accident.\n",
    "CHECKPOINT_DIR = \"./ChatGLM2-6B/ptuning/output/0-chatglm2-6b-pt-128-2e-2/checkpoint-3000\"\n",
    "\n",
    "# NOTE: trust_remote_code=True runs Python code shipped with the checkpoint;\n",
    "# only point CHECKPOINT_DIR at checkpoints you trust.\n",
    "tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True)\n",
    "# Convert the weights to half precision and move the model to the GPU.\n",
    "model = AutoModel.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True).half().cuda()\n",
    "# Switch to evaluation mode before running inference.\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the loaded model on every record of data/val0.json and store one\n",
    "# result file per record under result/.\n",
    "import json\n",
    "import os\n",
    "\n",
    "# The context manager closes the input file once its lines are read.\n",
    "with open('data/val0.json', 'r', encoding='utf-8') as file:\n",
    "    lines = file.readlines()\n",
    "\n",
    "# Without this, the first write fails with FileNotFoundError on a fresh checkout.\n",
    "os.makedirs('result', exist_ok=True)\n",
    "\n",
    "index = 0\n",
    "for line in lines:\n",
    "    line = line.strip()\n",
    "    if not line:\n",
    "        # Skip blank lines (e.g. a stray empty line) instead of crashing in json.loads.\n",
    "        continue\n",
    "    obj = json.loads(line)\n",
    "    content = obj['content']\n",
    "    summary = obj['summary']\n",
    "    print(content)\n",
    "    # history=[] makes every record an independent, single-turn conversation.\n",
    "    response, history = model.chat(tokenizer, content, history=[])\n",
    "    print(response)\n",
    "    # 'w' instead of 'a': appending on a re-run would leave two concatenated JSON\n",
    "    # documents in the same file, which json.load can no longer parse.\n",
    "    with open('result/0' + str(index) + '.json', 'w', encoding='utf8') as out:\n",
    "        json.dump({'content': content, 'summary': summary, 'response': response, 'history': history}, out, ensure_ascii=False)\n",
    "    index = index + 1\n",
    "    # Progress: number of records processed so far.\n",
    "    print(index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: send a single greeting to the loaded model and show its reply.\n",
    "response, history = model.chat(\n",
    "    tokenizer,\n",
    "    \"你好\",\n",
    "    history=[],\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# calculate the QWK\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import cohen_kappa_score\n",
    "import json\n",
    "\n",
    "# draw confusion matrix\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "\n",
    "def qwk(seq, step, n_folds=5, n_runs=5, fold_size=120):\n",
    "    \"\"\"Compute the quadratic weighted kappa over all folds for one configuration.\n",
    "\n",
    "    For every fold ``index`` in ``range(n_folds)`` and every run ``time`` in\n",
    "    ``1..n_runs`` this reads\n",
    "    ``ChatGLM2-6B/ptuning/output/<index>-chatglm2-6b-pt-<seq>-1e-2/<step>/<time>/generated_predictions.txt``,\n",
    "    which must hold one JSON object per line with ``labels`` and ``predict``\n",
    "    strings. The values gathered for one sample across the runs are reduced\n",
    "    with a trimmed mean (minimum and maximum dropped) and rounded to an int.\n",
    "\n",
    "    Args:\n",
    "        seq: value substituted into the run directory name (``pt-<seq>``).\n",
    "        step: name of the sub-directory below the run directory.\n",
    "        n_folds: number of fold directories to read (folds ``0..n_folds-1``).\n",
    "        n_runs: number of repeated runs per fold (directories ``1..n_runs``);\n",
    "            must be at least 3 because the trimmed mean drops two values.\n",
    "        fold_size: number of prediction lines expected per file; line ``k`` of\n",
    "            fold ``f`` is stored in slot ``fold_size * f + k``.\n",
    "\n",
    "    Returns:\n",
    "        The value of ``cohen_kappa_score(y_true, y_pred, weights='quadratic')``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``n_runs`` is smaller than 3, or (from ``max``) if a slot\n",
    "            received no values because a file had fewer than ``fold_size`` lines.\n",
    "    \"\"\"\n",
    "    if n_runs < 3:\n",
    "        raise ValueError('n_runs must be at least 3 for the trimmed mean')\n",
    "\n",
    "    n_samples = n_folds * fold_size\n",
    "    y_true = [[] for _ in range(n_samples)]\n",
    "    y_pred = [[] for _ in range(n_samples)]\n",
    "\n",
    "    for index in range(n_folds):\n",
    "        for time in range(1, n_runs + 1):\n",
    "            filename = \"ChatGLM2-6B/ptuning/output/\" + str(index) + \"-chatglm2-6b-pt-\" + str(seq) + \"-1e-2/\" + str(step) + \"/\" + str(time) + \"/generated_predictions.txt\"\n",
    "            # Explicit UTF-8 (the labels contain Chinese text) and a context\n",
    "            # manager so the handle is closed; presumably the files were\n",
    "            # written as UTF-8 -- confirm against the prediction writer.\n",
    "            with open(filename, encoding='utf-8') as f:\n",
    "                res = f.read()\n",
    "            idx = 0\n",
    "            for r in res.split('\\n'):\n",
    "                if r == \"\":\n",
    "                    continue\n",
    "                rJSON = json.loads(r)\n",
    "                labels = rJSON['labels']\n",
    "                predict = rJSON['predict']\n",
    "                slot = fold_size * index + idx\n",
    "                # Assumes labels look like the training targets built earlier in\n",
    "                # this notebook (a 3-character prefix followed by the score),\n",
    "                # so character 3 is the single score digit.\n",
    "                y_true[slot].append(int(labels[3]))\n",
    "                # Smallest digit 0-5 that occurs anywhere in the prediction text;\n",
    "                # falls back to 0 when the prediction contains none of them.\n",
    "                y_pred[slot].append(next((score for score in range(6) if str(score) in predict), 0))\n",
    "                idx = idx + 1\n",
    "\n",
    "    def avgWithoutMinMax(arr):\n",
    "        # Trimmed mean: drop one minimum and one maximum, average the rest, round.\n",
    "        return round((sum(arr) - max(arr) - min(arr)) / (len(arr) - 2))\n",
    "\n",
    "    for i in range(len(y_true)):\n",
    "        y_true[i] = avgWithoutMinMax(y_true[i])\n",
    "        y_pred[i] = avgWithoutMinMax(y_pred[i])\n",
    "    return cohen_kappa_score(y_true, y_pred, weights='quadratic')\n",
    "\n",
    "# Print the QWK for every seq setting and every step value from 100 to 2000.\n",
    "for seq in (128, 256, 512, 1024):\n",
    "    print(seq)\n",
    "    for step in range(100, 2001, 100):\n",
    "        print(qwk(seq, step))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
